import os
import json
from typing import Dict, List
from .prompts import *



class MergeSolver:
    """Decompose a goal into a hierarchy of sub-skills by querying an LLM.

    The hierarchy ("alphabet") is stored in ``self.alphabet_mapping`` as
    ``{skill_name: {"description": <LLM-provided value>, "level": int,
    "decomposition": [child_skill_name, ...]}}``. Decomposition stops after
    level-2 skills, so level-3 skills have no ``"decomposition"`` key.

    State is loaded from / saved to JSON files in a checkpoint directory.
    LLM calls go through ``self.client`` (used as ``client.chat.completions
    .create``) with ``self.model_name``; both must be available as attributes
    before :meth:`solve` runs, either via the constructor or assigned later.
    """

    def __init__(self, goal: str, checkpoint: str, client=None, model_name=None):
        """Create a solver for ``goal`` backed by the ``checkpoint`` directory.

        Args:
            goal: Natural-language goal to decompose.
            checkpoint: Directory holding the JSON checkpoint files; created
                if it does not exist.
            client: Optional chat-completions client. Assigned only when given
                so an attribute provided some other way is not overwritten.
            model_name: Optional model identifier, handled like ``client``.
        """
        self.goal = goal
        self.checkpoint = checkpoint
        if client is not None:
            self.client = client
        if model_name is not None:
            self.model_name = model_name
        self.load_checkpoint()

    def solve(self):
        """Build the skill hierarchy for ``self.goal``.

        Returns:
            The populated ``self.alphabet_mapping``.
        """
        # construct_alphabet populates self.alphabet_mapping in place.
        self.construct_alphabet()
        # TODO: self.automata_mapping = self.construct_automata(self.goal)
        return self.alphabet_mapping

    def construct_alphabet(self):
        """Ask the LLM for the main goal, then decompose it recursively.

        Returns:
            ``self.alphabet_mapping`` after population.

        Raises:
            RuntimeError: If the LLM never produced valid JSON, or produced an
                empty object, for the main-goal query.
        """
        main_goal_msg = self._construct_main_goal_message()
        json_resp = self._get_json_resp(main_goal_msg)
        if not json_resp:
            raise RuntimeError("LLM did not return a main goal")
        # Dict views are not indexable in Python 3; take the first key.
        main_goal = next(iter(json_resp))

        self.alphabet_mapping[main_goal] = {"description": json_resp[main_goal], "level": 0}
        self._construct_alphabet(main_goal, level=0)
        return self.alphabet_mapping

    def _construct_alphabet(self, goal: str, level: int = 0):
        """Ask the LLM to decompose ``goal`` and recurse into new subgoals.

        Subgoals already present in ``self.alphabet_mapping`` are neither
        re-added to this goal's decomposition nor decomposed again. Recursion
        stops after decomposing goals at level 2.

        Args:
            goal: Existing key of ``self.alphabet_mapping`` to decompose.
            level: Level of ``goal``; new subgoals get ``level + 1``.

        Raises:
            RuntimeError: If the LLM never produced valid JSON.
        """
        subgoals_msg = self._construct_alphabet_message(goal, level=level)
        json_resp = self._get_json_resp(subgoals_msg)
        if json_resp is None:
            raise RuntimeError(f"LLM did not return a decomposition for {goal!r}")

        decomposition = []
        self.alphabet_mapping[goal]['decomposition'] = decomposition
        for subgoal, description in json_resp.items():
            if subgoal in self.alphabet_mapping:
                continue
            decomposition.append(subgoal)
            self.alphabet_mapping[subgoal] = {"description": description, "level": level + 1}
        if level >= 2:
            return
        # Recurse only into skills created here: recursing into pre-existing
        # ones would re-query the LLM and overwrite their decomposition.
        for subgoal in decomposition:
            self._construct_alphabet(subgoal, level=level + 1)

    def _get_skills_by_level(self, level: int):
        """Return the names of all known skills at ``level``, in insertion order."""
        return [name for name, info in self.alphabet_mapping.items() if info['level'] == level]

    def _construct_main_goal_message(self):
        """Build the chat messages asking the LLM to name the main goal."""
        return [
            {"role": "system", "content": MAIN_GOAL_SYS_PROMPT},
            {"role": "user", "content": MAIN_GOAL_USER_PROMPT.format(self.goal)}
        ]

    def _construct_alphabet_message(self, goal: str, level: int = 0):
        """Build the chat messages asking the LLM to decompose ``goal``.

        Goals below level 2 use the SUBGOAL prompts with level-0 skills as
        context; deeper goals use the SUBGOAL_BASE prompts with level-1 skills.
        NOTE(review): confirm these context levels against the prompt
        templates — ``level + 1`` skills may have been intended.
        """
        if level < 2:
            return [
                {"role": "system", "content": SUBGOAL_SYS_PROMPT},
                {"role": "user", "content": SUBGOAL_USER_PROMPT.format(goal, self._get_skills_by_level(0))}
            ]
        else:
            return [
                {"role": "system", "content": SUBGOAL_BASE_SYS_PROMPT},
                {"role": "user", "content": SUBGOAL_BASE_USER_PROMPT.format(goal, self._get_skills_by_level(1))}
            ]

    def _get_json_resp(self, queries, seed: int = 0, retry: int = 10, max_tokens=4096):
        """Query the LLM in JSON mode and parse the reply.

        Each attempt uses a different seed (``seed + attempt``). Unparseable
        replies are printed and retried.

        Returns:
            The parsed JSON value, or ``None`` if all ``retry`` attempts
            failed to parse.
        """
        for i in range(retry):
            try:
                resp = self._get_openai_resp(queries, seed=seed + i, max_tokens=max_tokens, response_format={"type": "json_object"})
                result = json.loads(resp)
                return result
            except json.decoder.JSONDecodeError:
                print(resp)
        return None

    def _get_openai_resp(self, queries: List[Dict], max_tokens=4096, n=1, seed=0, response_format=None):
        """Send ``queries`` to ``self.client`` and return the first choice's text."""
        response = self.client.chat.completions.create(
            model=self.model_name, # Specify the GPT-4 engine
            response_format=response_format, # {"type": "json_object"}
            messages=queries,
            max_tokens=max_tokens, # Maximum number of tokens in the response
            n=n, # Number of completions to generate
            stop=None, # Token at which to stop generating further tokens
            temperature=None, # Controls the randomness of the response
            seed=seed
        )
        return response.choices[0].message.content

    def load_checkpoint(self):
        """Restore state from the checkpoint directory, creating it if needed.

        ``llm_query_cache``, ``alphabet_mapping`` and ``automata_mapping`` are
        each loaded from their JSON file, or set to ``{}`` when the file (or
        the whole directory) does not exist yet.
        """
        if not os.path.exists(self.checkpoint):
            os.makedirs(self.checkpoint)
        # NOTE(review): llm_query_cache is persisted but no method in this
        # class reads or writes it yet.
        self.llm_query_cache = self._load_json("llm_query_cache.json")
        self.alphabet_mapping = self._load_json("alphabet_mapping.json")
        self.automata_mapping = self._load_json("automata_mapping.json")

    def _load_json(self, filename: str):
        """Return the parsed JSON file ``filename`` in the checkpoint dir, or ``{}`` if absent."""
        path = os.path.join(self.checkpoint, filename)
        if not os.path.exists(path):
            return {}
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def save_checkpoint(self):
        """Write the three state dicts as JSON files into the checkpoint dir."""
        if not os.path.exists(self.checkpoint):
            os.makedirs(self.checkpoint)
        for filename, data in (
            ("llm_query_cache.json", self.llm_query_cache),
            ("alphabet_mapping.json", self.alphabet_mapping),
            ("automata_mapping.json", self.automata_mapping),
        ):
            # `with` guarantees the file is flushed and closed after dumping.
            with open(os.path.join(self.checkpoint, filename), "w", encoding="utf-8") as f:
                json.dump(data, f)
